
Conversation

@Dokujaa (Contributor) commented Jul 18, 2025:

  • Add OAuth2 token caching functions to async_cache.py and cache.py
  • Create async_oauth_token_cache with 55-minute TTL (5-min safety buffer)
  • Update vertex_authentication method to check cache first before token refresh
  • Add OAuth2 token cache statistics tracking
  • Cache key format: token:{api_key}
  • Performance optimization: reduce unnecessary token refresh calls

Let me know if this implementation is sufficient; a sketch of the flow follows. All the test cases I ran passed.
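For context, a minimal sketch of the cache-first flow described above (the helper names follow the diff discussed later in this thread; _refresh_vertex_token is a hypothetical stand-in for the real refresh path):

async def vertex_authentication(self, api_key: str) -> str:
    # Check the cache first; only refresh when there is no valid entry
    cached_token = await get_cached_oauth_token_async(api_key)
    if cached_token:
        return cached_token["access_token"]

    # Cache miss: refresh and store the new token for subsequent requests
    token_data = await _refresh_vertex_token(api_key)  # hypothetical refresh call
    await cache_oauth_token_async(api_key, token_data)
    return token_data["access_token"]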

@lingtonglu mentioned this pull request Jul 18, 2025
ttl_seconds=3600
) # 1-hour TTL
# OAuth2 token caching (55-min TTL with 5-min safety buffer before 1-hour token expiry)
async_oauth_token_cache: "AsyncCache" = _AsyncBackend(ttl_seconds=3300) # 55-min TTL
Contributor:

Let's be conservative about this and use a 40-minute TTL here.

Contributor:

I think 55 minutes is more reasonable if we add a refresh-token mechanism.

Contributor:

I've reconsidered the TTL setup and now believe we should avoid using a fixed TTL in favor of storing an explicit expiration timestamp with each token. This way, we can validate tokens by comparing the stored expires_at against the current time.

Using both a TTL and the token's own expiration can introduce unnecessary complexity and inconsistencies. For example, if the Redis TTL is set to 40 minutes but the token itself is valid for 60 minutes, we'd lose the cache entry prematurely at the 40-minute mark, causing an unnecessary token refresh. Conversely, if the token expires at 60 minutes but the cache TTL is longer, we might end up using an invalid token that appears "alive" in the cache.

By tracking expires_at explicitly, we can avoid these edge cases and handle token refresh more reliably through expiration checks and error-based retries.
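For illustration, a minimal validation sketch under that design (assuming the token dict carries a numeric expires_at Unix timestamp, as proposed below; the skew buffer is an added assumption):

import time

def is_token_valid(token_data: dict, skew_seconds: float = 60.0) -> bool:
    # Valid only while the token's own expiry, minus a small clock-skew
    # buffer, is still in the future; no cache-level TTL is involved.
    expires_at = token_data.get("expires_at")
    return expires_at is not None and expires_at - skew_seconds > time.time()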

Contributor (Author):

I think this is good. One thing I am concerned about is cleaning up old tokens; stale entries could cause memory leaks.

If we do this:
async_oauth_token_cache: "AsyncCache" = _AsyncBackend(ttl_seconds=None) # No TTL

Some potential problems:

  1. Expired tokens accumulate forever → memory leak
  2. Cache grows indefinitely
  3. No automatic cleanup mechanism

I think we can handle that, though. In vertex_adapter.py:

token_data = {
    "access_token": credentials.token,
    "token_type": "Bearer",
    "expires_at": credentials.expiry.timestamp(),  # Unix timestamp
    "scope": "https://www.googleapis.com/auth/cloud-platform",
    "cached_at": time.time(),  # For debugging
    "provider": "vertex"  # Helpful for multi-provider systems
}

Using expires_at to clean up tokens:

import hashlib
import logging
import os
import time
from typing import Any

logger = logging.getLogger(__name__)

async_oauth_token_cache: "AsyncCache" = _AsyncBackend(ttl_seconds=None)

async def get_cached_oauth_token_async(api_key: str) -> dict[str, Any] | None:
    if not api_key:
        return None
    # Hash the (potentially very long JSON) api_key into a fixed-length cache key
    cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"
    cached_data = await async_oauth_token_cache.get(cache_key)
    if not cached_data:
        return None
    expires_at = cached_data.get("expires_at")
    if not expires_at:
        # Malformed entry without an expiry: treat as invalid and evict
        await async_oauth_token_cache.delete(cache_key)
        return None
    current_time = time.time()
    if expires_at <= current_time:
        # Expired token: evict it and piggyback a small best-effort cleanup
        await async_oauth_token_cache.delete(cache_key)
        await _opportunistic_cleanup(current_time, max_items=2)
        return None
    return cached_data

async def _opportunistic_cleanup(current_time: float, max_items: int = 2) -> None:
    # Best-effort eviction: remove at most `max_items` expired tokens per
    # call so no single request pays a large cleanup cost.
    cleaned = 0
    # In-memory backend: iterate over the backing dict directly
    if hasattr(async_oauth_token_cache, "cache"):
        for key, value in list(async_oauth_token_cache.cache.items()):
            if cleaned >= max_items:
                break
            if key.startswith("token:"):
                expires_at = value.get("expires_at")
                if expires_at and expires_at <= current_time:
                    await async_oauth_token_cache.delete(key)
                    cleaned += 1
    # Redis backend: scan for token keys under the configured prefix
    elif hasattr(async_oauth_token_cache, "client"):
        try:
            pattern = f"{os.getenv('REDIS_PREFIX', 'forge')}:token:*"
            async for redis_key in async_oauth_token_cache.client.scan_iter(match=pattern, count=10):
                if cleaned >= max_items:
                    break
                key_str = redis_key.decode() if isinstance(redis_key, bytes) else redis_key
                internal_key = key_str.split(":", 1)[-1]
                cached_data = await async_oauth_token_cache.get(internal_key)
                if cached_data:
                    expires_at = cached_data.get("expires_at")
                    if expires_at and expires_at <= current_time:
                        await async_oauth_token_cache.delete(internal_key)
                        cleaned += 1
        except Exception:
            # Cleanup is best-effort; never let it break the token path
            logger.debug("Opportunistic token cleanup failed", exc_info=True)

async def cache_oauth_token_async(api_key: str, token_data: dict[str, Any]) -> None:
    if not api_key or not token_data:
        return
    cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"
    if "expires_at" not in token_data:
        logger.warning("OAuth token cached without expires_at - skipping")
        return
    await async_oauth_token_cache.set(cache_key, token_data)
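A hypothetical usage sketch tying the two helpers together (the get_vertex_access_token name and adapter wiring are assumptions; the refresh call uses google-auth, as in vertex_adapter.py above):

import time

from google.auth.transport.requests import Request

async def get_vertex_access_token(api_key: str, credentials) -> str:
    # Cache hit: reuse the token while its expires_at is still in the future
    cached = await get_cached_oauth_token_async(api_key)
    if cached:
        return cached["access_token"]

    # Cache miss or expired entry: re-authenticate, then cache with the
    # token's native expiry (mirrors the token_data shape above)
    credentials.refresh(Request())  # blocking google-auth call; fine for a sketch
    token_data = {
        "access_token": credentials.token,
        "token_type": "Bearer",
        "expires_at": credentials.expiry.timestamp(),
        "cached_at": time.time(),
        "provider": "vertex",
    }
    await cache_oauth_token_async(api_key, token_data)
    return token_data["access_token"]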

@@ -118,6 +118,8 @@ def stats(self) -> dict[str, Any]:
# Expose the global cache instances
Contributor:

We're working on a PR to fully switch to async and will discard the sync cache very soon. We don't need to support this anymore.

self.parse_api_key(api_key)

# check cache first for existing valid token
cached_token = await get_cached_oauth_token_async(api_key)
Contributor:

I actually thought about caching the OAuth2 token while implementing this. My biggest concern is that, if for some reason the cached token is invalidated before its expiry time, there is no way for users to manually invalidate the cache and trigger new token generation. They would have to wait 40-50 minutes for the cache to expire.

And if we choose to run a simple connection test on the token before each use, performance would degrade and fall back to the uncached version. Any suggestions?
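For concreteness, the manual escape hatch I have in mind would look something like this hypothetical helper (reusing the cache's delete path):

import hashlib

async def invalidate_oauth_token_async(api_key: str) -> None:
    # Hypothetical hook: drop the cached token so the next request is forced
    # to re-authenticate instead of waiting out the remaining lifetime.
    if not api_key:
        return
    cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"
    await async_oauth_token_cache.delete(cache_key)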

@wilsonccccc (Contributor) commented Jul 18, 2025:

I believe Vertex follows the standard OAuth 2.0 protocol, which returns two key tokens upon authorization:

  1. Access Token – Used to authenticate user access to resources (e.g., Vertex APIs). It typically has a short time-to-live (TTL), such as 1 hour.
  2. Refresh Token – Used solely to obtain a new access token once the original expires. It enables seamless long-term authentication without requiring the user to log in again.

The typical workflow looks like this:
[User Logs In / Grants Access] → [Authorization Server returns access_token + refresh_token] → access_token is used to call Vertex APIs → refresh_token is used to obtain a new access_token when needed.

So ideally the token logic would be:

import json
import time

def get_vertex_token(api_key):
    raw = redis.get(f"vertex_token:{api_key}")
    if not raw:
        raise Unauthorized("Vertex not bound")
    token_data = json.loads(raw)

    # If the access token has not expired, use it directly
    if token_data["expires_at"] > time.time():
        return token_data["access_token"]

    # If there is a refresh_token, refresh the access token
    if "refresh_token" in token_data:
        new_token = refresh_vertex_token(token_data["refresh_token"])
        redis.set(f"vertex_token:{api_key}", json.dumps(new_token), ex=3600)
        return new_token["access_token"]

    # Cannot refresh; the user has to re-authenticate
    raise TokenExpired("Please reconnect your Vertex account")

As a follow-up step in the next PR, we should store the refresh token, encrypted with Fernet (fernet.encrypt(refresh_token)), in the Postgres DB in a new table named something like user_token; a minimal round-trip is sketched below. (Please create a new issue as a record if you agree with my suggestion.)
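For reference, a minimal sketch of that Fernet round-trip (using cryptography's Fernet; the key handling and token value are placeholders):

from cryptography.fernet import Fernet

# In practice the key comes from config/secrets, not generated inline
fernet = Fernet(Fernet.generate_key())

refresh_token = "example-refresh-token"             # placeholder value
encrypted = fernet.encrypt(refresh_token.encode())  # bytes, safe to store in user_token
assert fernet.decrypt(encrypted).decode() == refresh_token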

Also, I suggest the access-token cache JSON should include 1. scope and 2. token_type ("Bearer", etc.).

Contributor:

I think for this case we can keep it straightforward and simpler: just always use the user's JSON file to re-authenticate if the current access token is expired or has been cleaned up; a sketch follows. Let's forget about the refresh token at this stage.
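A minimal sketch of that re-authentication path with google-auth (assuming the stored api_key is the service-account JSON string, as discussed below; the reauthenticate name is illustrative):

import json
from google.auth.transport.requests import Request
from google.oauth2 import service_account

def reauthenticate(service_account_json: str):
    # Rebuild credentials from the user's JSON contents and fetch a fresh
    # access token; service accounts don't need a refresh token for this.
    info = json.loads(service_account_json)
    credentials = service_account.Credentials.from_service_account_info(
        info, scopes=["https://www.googleapis.com/auth/cloud-platform"]
    )
    credentials.refresh(Request())
    return credentials  # credentials.token / credentials.expiry now populated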

"token": credentials.token,
"expiry": credentials.expiry.isoformat()
}
await cache_oauth_token_async(api_key, token_data)
Contributor:

Unlike other providers, the API key for OAuth2 is a long JSON string, which is not suited for use as a cache key. Let's use a hash function to generate the cache key.
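For illustration, hashing collapses the multi-kilobyte service-account JSON into a fixed-length, Redis-safe key (placeholder value shown):

import hashlib

service_account_json = '{"type": "service_account"}'  # placeholder
cache_key = f"token:{hashlib.sha256(service_account_json.encode()).hexdigest()}"
# Always 64 hex characters, regardless of the input's size or contents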

@wilsonccccc (Contributor):
Thanks so much for this optimization; very nice work overall.

@wilsonccccc (Contributor):
I think this PR should be rebased on the database and async cache refactor from #16.
If you have time, feel free to take a look at that PR — it looks good to me. Once it's merged, please rebase your changes, and I’ll review your PR again.

@Dokujaa force-pushed the vertex_support branch 2 times, most recently from 04d6d4d to 0d20eaa (July 20, 2025 20:36)
lingtonglu and others added 3 commits July 20, 2025 16:44
- Add OAuth2 token caching functions to async_cache.py and cache.py
- Create async_oauth_token_cache with 55-minute TTL (5-min safety buffer)
- Update vertex_authentication method to check cache first before token refresh
- Add OAuth2 token cache statistics tracking
- Cache key format: token:{api_key}
- Performance optimization: reduce unnecessary token refresh calls
  * Replace fixed TTL with token's native expires_at timestamp validation
  * Use SHA-256 hashed cache keys for long JSON service account credentials
  * Add opportunistic cleanup to prevent expired token memory leaks
  * Standardize token structure with access_token, expires_at, token_type fields
  * Simplify Vertex AI authentication using service account re-authentication
  * Move imports to top-level for better code quality

  Addresses PR feedback on cache key security, TTL complexity, and memory management.
@Dokujaa (Contributor, Author) commented Jul 20, 2025:

I am going to close this PR and make a PR to the main branch, as there seem to be conflicts with the vertex_support branch.

@Dokujaa closed this Jul 26, 2025